# cython: infer_types=True, profile=True
from libc.stdint cimport uintptr_t
from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter

import warnings

from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, MORPH
from ..attrs import IDS
from ..structs cimport TokenC
from ..tokens.token cimport Token
from ..tokens.span cimport Span
from ..typedefs cimport attr_t

from ..schemas import TokenPattern
from ..errors import Errors, Warnings


cdef class PhraseMatcher:
    """Efficiently match large terminology lists. While the `Matcher` matches
    sequences based on lists of token descriptions, the `PhraseMatcher` accepts
    match patterns in the form of `Doc` objects.

    DOCS: https://spacy.io/api/phrasematcher
    USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher

    Adapted from FlashText: https://github.com/vi3k6i5/flashtext
    MIT License (see `LICENSE`)
    Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
    """

    def __init__(self, Vocab vocab, attr="ORTH", validate=False):
        """Create an empty PhraseMatcher.

        vocab (Vocab): The vocabulary shared with the documents to match.
        attr (int / str): Token attribute the patterns are compared on. An
            integer is stored as-is. A string is upper-cased, "TEXT" is
            treated as "ORTH" and "IS_SENT_START" as "SENT_START"; `None`
            falls back to "ORTH".
        validate (bool): Enable the extra checks performed when patterns are
            added.

        DOCS: https://spacy.io/api/phrasematcher#init
        """
        self.vocab = vocab
        self._validate = validate
        self._callbacks = {}
        self._docs = {}

        # Root node of the trie. `_terminal_hash` is the reserved key under
        # which a node stores the match IDs of the patterns ending there.
        self.mem = Pool()
        self._terminal_hash = 826361138722620965
        self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
        map_init(self.mem, self.c_map, 8)

        # Integer attribute IDs need no normalisation.
        if isinstance(attr, (int, long)):
            self.attr = attr
            return

        name = ("ORTH" if attr is None else attr).upper()
        aliases = {"TEXT": "ORTH", "IS_SENT_START": "SENT_START"}
        name = aliases.get(name, name)
        # Only attributes known to the token pattern schema are accepted.
        if name.lower() not in TokenPattern().dict():
            raise ValueError(Errors.E152.format(attr=name))
        self.attr = IDS.get(name)

    def __len__(self):
        """Count the match IDs currently registered on the matcher.

        RETURNS (int): The number of rules.

        DOCS: https://spacy.io/api/phrasematcher#len
        """
        # One callback slot exists per match ID, even when the callback is None.
        registered = self._callbacks
        return len(registered)

    def __contains__(self, key):
        """Tell whether rules have been added under a given match ID.

        key (str): The match ID.
        RETURNS (bool): True if the match ID is registered on this matcher.

        DOCS: https://spacy.io/api/phrasematcher#contains
        """
        registered = self._callbacks
        return key in registered

    def __reduce__(self):
        """Describe how to pickle the matcher.

        The matcher is rebuilt by `unpickle_matcher` from the vocab, the
        stored patterns, the callbacks and the match attribute. The
        `validate` flag is not part of the pickled state.

        RETURNS (tuple): The `(callable, args, state, listitems)` reduce value.
        """
        state = (self.vocab, self._docs, self._callbacks, self.attr)
        return unpickle_matcher, state, None, None

    def remove(self, key):
        """Remove a rule from the matcher by match ID. A KeyError is raised if
        the key does not exist.

        Every pattern stored under the key is looked up in the trie. If the
        key is the only match ID ending at that position, the now-unneeded
        trie path is pruned from the leaf upwards; otherwise only the key's
        entry in the terminal map is cleared.

        key (str): The match ID.
        RAISES (KeyError): If no rule was added under `key`.

        DOCS: https://spacy.io/api/phrasematcher#remove
        """
        if key not in self._docs:
            raise KeyError(key)
        cdef MapStruct* current_node
        cdef MapStruct* terminal_map
        cdef MapStruct* node_pointer
        cdef void* result
        cdef key_t terminal_key
        cdef void* value
        cdef int c_i = 0
        cdef vector[MapStruct*] path_nodes
        cdef vector[key_t] path_keys
        cdef key_t key_to_remove
        # Longest patterns first, so shared prefixes are pruned leaf-to-root.
        for keyword in sorted(self._docs[key], key=len, reverse=True):
            current_node = self.c_map
            path_nodes.clear()
            path_keys.clear()
            for token in keyword:
                if token == self._terminal_hash:
                    # `add` stops inserting at this token, so the stored
                    # path ends here as well.
                    break
                result = map_get(current_node, token)
                if not result:
                    # token is not in the trie: nothing to remove
                    current_node = NULL
                    break
                path_nodes.push_back(current_node)
                path_keys.push_back(token)
                current_node = <MapStruct*>result
            # Never hand a NULL node to map_get.
            if current_node == NULL:
                continue
            result = map_get(current_node, self._terminal_hash)
            if not result:
                continue
            terminal_map = <MapStruct*>result
            # The terminal entry is the deepest step of the path. Including
            # it means pruning starts at the terminal map itself, so a node
            # that still leads to longer patterns is kept.
            path_nodes.push_back(current_node)
            path_keys.push_back(self._terminal_hash)
            terminal_keys = []
            c_i = 0
            while map_iter(terminal_map, &c_i, &terminal_key, &value):
                terminal_keys.append(self.vocab.strings[terminal_key])
            if terminal_keys == [key]:
                # This key is the only one ending here: remove the path
                # bottom-up until a node with other entries is reached.
                while not path_nodes.empty():
                    node_pointer = path_nodes.back()
                    path_nodes.pop_back()
                    key_to_remove = path_keys.back()
                    path_keys.pop_back()
                    result = map_get(node_pointer, key_to_remove)
                    had_single_entry = node_pointer.filled == 1
                    map_clear(node_pointer, key_to_remove)
                    self.mem.free(result)
                    if not had_single_entry:
                        # more than one key means more than 1 path,
                        # the remaining paths must be kept
                        break
            else:
                # Other match IDs end here too: only drop this key.
                map_clear(terminal_map, self.vocab.strings[key])

        del self._callbacks[key]
        del self._docs[key]

    def add(self, key, docs, *_docs, on_match=None):
        """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        As of spaCy v2.2.2, PhraseMatcher.add supports the future API, which
        makes the patterns the second argument and a list (instead of a variable
        number of arguments). The on_match callback becomes an optional keyword
        argument.

        All patterns are validated and converted before the matcher is
        modified, so a call that raises leaves the matcher unchanged.

        key (str): The match ID.
        docs (list): List of `Doc` objects representing match patterns.
        on_match (callable): Callback executed on match.
        *_docs (Doc): For backwards compatibility: list of patterns to add
            as variable arguments. Will be ignored if a list of patterns is
            provided as the second argument.
        RAISES (ValueError): If a single `Doc` is passed instead of a list, or
            if a pattern lacks the annotation required by the match attribute.

        DOCS: https://spacy.io/api/phrasematcher#add
        """
        cdef MapStruct* current_node
        cdef MapStruct* internal_node
        cdef void* result

        if docs is None or hasattr(docs, "__call__"):  # old API
            on_match = docs
            docs = _docs

        if isinstance(docs, Doc):
            raise ValueError(Errors.E179.format(key=key))

        # First pass: validate and convert every pattern without touching
        # the matcher's state.
        attrs = (TAG, POS, MORPH, LEMMA, DEP)
        keywords = []
        for doc in docs:
            if len(doc) == 0:
                continue
            if isinstance(doc, Doc):
                has_annotation = {attr: doc.has_annotation(attr) for attr in attrs}
                for attr in attrs:
                    if self.attr == attr and not has_annotation[attr]:
                        if attr == TAG:
                            pipe = "tagger"
                        elif attr in (POS, MORPH):
                            pipe = "morphologizer or tagger+attribute_ruler"
                        elif attr == LEMMA:
                            pipe = "lemmatizer"
                        elif attr == DEP:
                            pipe = "parser"
                        error_msg = Errors.E155.format(pipe=pipe, attr=self.vocab.strings.as_string(attr))
                        raise ValueError(error_msg)
                if self._validate and any(has_annotation.values()) \
                        and self.attr not in attrs:
                    string_attr = self.vocab.strings[self.attr]
                    warnings.warn(Warnings.W012.format(key=key, attr=string_attr))
                keyword = self._convert_to_array(doc)
            else:
                # Already a sequence of attribute values (e.g. when unpickling).
                keyword = doc
            keywords.append(tuple(keyword))

        # Second pass: register the key and insert the patterns.
        _ = self.vocab[key]
        self._callbacks[key] = on_match
        self._docs.setdefault(key, set()).update(keywords)

        for keyword in keywords:
            current_node = self.c_map
            for token in keyword:
                if token == self._terminal_hash:
                    # The reserved terminal key cannot be used as a token.
                    warnings.warn(Warnings.W021)
                    break
                result = <MapStruct*>map_get(current_node, token)
                if not result:
                    internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
                    map_init(self.mem, internal_node, 8)
                    map_set(self.mem, current_node, token, internal_node)
                    result = internal_node
                current_node = <MapStruct*>result
            # Mark the end of the pattern: the terminal map of the last node
            # has one entry per match ID ending there.
            result = <MapStruct*>map_get(current_node, self._terminal_hash)
            if not result:
                internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
                map_init(self.mem, internal_node, 8)
                map_set(self.mem, current_node, self._terminal_hash, internal_node)
                result = internal_node
            map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)

    def __call__(self, object doclike, *, as_spans=False):
        """Run the matcher over a `Doc` or `Span` and collect all matches.

        Each `on_match` callback registered for a matched ID is invoked with
        `(matcher, doc, i, matches)` once the full match list has been built.

        doclike (Doc or Span): The document to match over.
        as_spans (bool): Return Span objects with labels instead of (match_id,
            start, end) tuples.
        RETURNS (list): A list of `(match_id, start, end)` tuples, each
            describing a span `doc[start:end]` with an integer `match_id`.
            If as_spans is set to True, a list of Span objects is returned.

        DOCS: https://spacy.io/api/phrasematcher#call
        """
        cdef vector[SpanC] c_matches

        # Nothing to match on: empty or missing input gives an empty result.
        if doclike is None or len(doclike) == 0:
            return []
        if not isinstance(doclike, (Doc, Span)):
            raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))

        if isinstance(doclike, Doc):
            doc = doclike
            start_idx = 0
            end_idx = len(doc)
        else:
            # A Span is matched within the token range of its parent Doc.
            doc = doclike.doc
            start_idx = doclike.start
            end_idx = doclike.end

        self.find_matches(doc, start_idx, end_idx, &c_matches)
        matches = [
            (c_matches[i].label, c_matches[i].start, c_matches[i].end)
            for i in range(c_matches.size())
        ]
        for position, (match_id, _start, _end) in enumerate(matches):
            callback = self._callbacks.get(self.vocab.strings[match_id])
            if callback is not None:
                callback(self, doc, position, matches)
        if not as_spans:
            return matches
        return [Span(doc, start, end, label=key) for key, start, end in matches]

    cdef void find_matches(self, Doc doc, int start_idx, int end_idx, vector[SpanC] *matches) nogil:
        # Scan the tokens doc[start_idx:end_idx] and append one SpanC to
        # `matches` for every pattern found. For each start position the trie
        # is walked from its root, following the value of `self.attr` of each
        # successive token. Whenever the current node has an entry under
        # `self._terminal_hash`, every key of that terminal map is reported
        # as a match with label=key covering [start, idy).
        cdef MapStruct* current_node = self.c_map
        cdef int start = 0
        cdef int idx = start_idx
        cdef int idy = start_idx
        cdef key_t key
        cdef void* value
        cdef int i = 0
        cdef SpanC ms
        cdef void* result
        while idx < end_idx:
            start = idx
            token = Token.get_struct_attr(&doc.c[idx], self.attr)
            # look for sequences from this position
            result = map_get(current_node, token)
            if result:
                current_node = <MapStruct*>result
                idy = idx + 1
                while idy < end_idx:
                    # report every pattern that ends right before token idy
                    result = map_get(current_node, self._terminal_hash)
                    if result:
                        i = 0
                        while map_iter(<MapStruct*>result, &i, &key, &value):
                            ms = make_spanstruct(key, start, idy)
                            matches.push_back(ms)
                    # try to extend the current path with the next token
                    inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
                    result = map_get(current_node, inner_token)
                    if result:
                        current_node = <MapStruct*>result
                        idy += 1
                    else:
                        break
                else:
                    # end of doc reached
                    # (the loop ended without `break`, so idy == end_idx and
                    # the terminal check for the final node is still pending)
                    result = map_get(current_node, self._terminal_hash)
                    if result:
                        i = 0
                        while map_iter(<MapStruct*>result, &i, &key, &value):
                            ms = make_spanstruct(key, start, idy)
                            matches.push_back(ms)
            # restart from the trie root at the next start position
            current_node = self.c_map
            idx += 1

    def pipe(self, stream, batch_size=1000, return_matches=False, as_tuples=False):
        """Match a stream of documents, yielding them in turn. Deprecated as of
        spaCy v3.0.

        stream (iterable): Documents, or `(doc, context)` pairs when
            `as_tuples` is set.
        batch_size (int): Accepted for compatibility; not used here.
        return_matches (bool): Yield `(doc, matches)` instead of just `doc`.
        as_tuples (bool): Treat stream items as `(doc, context)` pairs and
            yield the result paired with its context.
        YIELDS: The documents, optionally with matches and/or context.
        """
        warnings.warn(Warnings.W105.format(matcher="PhraseMatcher"), DeprecationWarning)
        for item in stream:
            if as_tuples:
                doc, context = item
            else:
                doc = item
            # The matcher always runs, even when the matches are not yielded,
            # so that on_match callbacks fire.
            matches = self(doc)
            output = (doc, matches) if return_matches else doc
            if as_tuples:
                yield (output, context)
            else:
                yield output

    def _convert_to_array(self, Doc doc):
        """Read the matcher's attribute from every token of a `Doc`.

        doc (Doc): The pattern document.
        RETURNS (list): One `self.attr` value per token, in token order.
        """
        values = []
        for i in range(len(doc)):
            values.append(Token.get_struct_attr(&doc.c[i], self.attr))
        return values


def unpickle_matcher(vocab, docs, callbacks, attr):
    """Rebuild a PhraseMatcher from the state produced by `__reduce__`.

    vocab (Vocab): The shared vocabulary.
    docs (dict): Stored patterns, keyed by match ID.
    callbacks (dict): `on_match` callbacks, keyed by match ID.
    attr (int / str): The token attribute to match on.
    RETURNS (PhraseMatcher): A matcher with all patterns added again.
    """
    restored = PhraseMatcher(vocab, attr=attr)
    for match_id in docs:
        # A missing callback entry is the same as no callback.
        restored.add(match_id, docs[match_id], on_match=callbacks.get(match_id))
    return restored


cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil:
    # Build a SpanC carrying only the match label and the token boundaries;
    # no other field of the struct is assigned here.
    cdef SpanC match_span
    match_span.start = start
    match_span.end = end
    match_span.label = label
    return match_span
